# -*- coding: utf-8 -*-
"""
Created on Fri Sep  2 10:30:09 2022

@author: wulong
"""
import numpy as np
from casadi import *
from Fast_2_model import*

# In[1] Fast EMPC 2
def formulate_opt_f2(initial_state, state_guess, state_bnds_lo, state_bnds_up,
                     slow_state_bnds, fast_1_state_bnds, fast_3_state_bnds,
                     slow_input_bnds, fast_1_input_bnds, fast_3_input_bnds,
                     input_guess, input_bnds_lo, input_bnds_up,
                     input_bin_bnds,
                     distb_bnds,
                     output_guess, output_bnds_lo, output_bnds_up,
                     fast_state_setpnts, fast_input_setpnts,
                     alpha, alpha_xf, alpha_uf, pred_horzn,
                     pf, pmg, pse, X_int_f1, X_int_f3,
                     i, del_Uc, Uc_0, Uz_0,
                     alpha_ue, Uc_nb, It_d_Uc, Uc_it_sq, tag,
                     pre_d_Uc, pre_sq_Uc, pre_sq_Uz,
                     pcm, ppn, rxi, rxi_as, yeb):
    """Build the CasADi NLP for the fast economic MPC of fast subsystem 2.

    The horizon has ``pred_horzn`` steps. Each step creates its own state
    decision variable. That variable is tied to the integrator output by an
    equality constraint (the "lifted" form). The decision vector is laid out
    per step k = 1..pred_horzn as::

        [U_{k-1} (Nuc_f2), X_k (Nx_f2), y_k (Ny_f2), e_k (1)]

    Any caller that slices the solution must use this same ordering.

    The slow subsystem and fast subsystems 1 and 3 are not optimised here.
    Their states and inputs enter only as fixed per-step values read from
    the ``slow_*`` / ``fast_1_*`` / ``fast_3_*`` arrays.

    ``Delta_f``, ``Nuc_f2``, ``Nx_f2``, ``Ny_f2``, ``I_ode_f2`` and
    ``out_ies_f2`` are not defined in this function. Presumably they come
    from ``from Fast_2_model import *``; confirm against that module.

    Parameters
    ----------
    initial_state : Initial fast-2 state (fixed, not a decision variable).
    state_guess, input_guess, output_guess : Per-step initial guesses,
        read row-wise as ``[k-1, :]``.
    state_bnds_lo, state_bnds_up : Bounds applied to every X_k.
    slow_state_bnds, slow_input_bnds : Per-step slow-subsystem states/inputs.
    fast_1_state_bnds, fast_3_state_bnds : Per-step fast-1 / fast-3 states.
    fast_1_input_bnds, fast_3_input_bnds : Per-step fast-1 / fast-3 inputs.
    input_bnds_lo, input_bnds_up : Bounds on U_{k-1}, multiplied by the
        step's fast-2 binary flag ``input_bin_bnds[k-1, 1]``.
    input_bin_bnds : Per-step binary inputs; column 1 is the fast-2 flag.
    distb_bnds : Per-step disturbances; element 2 enters the economic cost.
    output_bnds_lo, output_bnds_up : Bounds applied to every y_k.
    fast_state_setpnts, fast_input_setpnts : Per-step reference values.
    alpha : ``alpha[0]`` weights output tracking, ``alpha[1]`` the
        economic term.
    alpha_xf, alpha_uf, alpha_ue : Quadratic weights on state deviation,
        input deviation and the slack variable.
    pred_horzn : Number of prediction steps.
    pf, pmg : Cost coefficients used without per-step indexing.
    pse, pcm, ppn : Per-step sell / compensation / penalty prices.
    X_int_f1, X_int_f3 : Fast-1 / fast-3 states used in the first step's
        integration.
    i : Outer time index. The rate constraint on step 1 is skipped only
        when ``i == 0``.
    del_Uc : Input rate limit; multiplied by ``Delta_f``.
    Uc_0, Uz_0 : Last applied continuous input and binary inputs
        (``Uz_0[1]`` is the fast-2 flag).
    Uc_nb : Half-width of the allowed band around ``fast_input_setpnts``.
    It_d_Uc, Uc_it_sq : Allowed change relative to the previous iterate
        ``Uc_it_sq``. Enforced only when ``tag != 1``.
    pre_d_Uc, pre_sq_Uc, pre_sq_Uz : Allowed change relative to the input
        sequence predicted at the previous time instant. The limit is
        relaxed where the binary flag differs from ``pre_sq_Uz[:, 1]``.
    rxi, rxi_as, yeb : Per-step terms. ``(1 + rxi) * yeb`` is the output
        target in the cost, and ``pcm * rxi_as * yeb`` is subtracted.

    Returns
    -------
    w, lbw, ubw, w0 : Stacked decision vector, its bounds and initial guess.
    g, lbg, ubg : Stacked constraint expressions and their bounds.
    J : Objective expression.
    """
    w = []
    w0 = []
    lbw = []
    ubw = []
    J = 0
    g = []
    lbg = []
    ubg = []
    
    # "Lift" initial states conditions
    Xk = initial_state
    
    # Other fast states initial conditions
    Xk_f1 = X_int_f1
    Xk_f3 = X_int_f3
    
    # Have last time inputs values
    past_u = Uc_0
    # Element 1 is the fast-2 binary flag (same index as Uzk[1] below)
    past_uz = Uz_0[1]
    
    # Per-step rate limit. Delta_f is presumably the fast sampling
    # period; confirm in Fast_2_model.
    del_bns_u = del_Uc*Delta_f
    
    for k in range(1, pred_horzn+1):
        # Define binary for inputs
        Uzk = input_bin_bnds[k-1,:]
        Zu_f = Uzk[1]
        
        # Slow subsystem states and inputs
        Xsk = slow_state_bnds[k-1,:]
        Usk = slow_input_bnds[k-1,:]
        
        # Create input variables
        uname = 'U' + str(k-1)
        Uk = MX.sym(uname, Nuc_f2)
        w += [Uk]
        # Scaling by the flag pins the input to 0 when Zu_f == 0
        # (assuming the flag is 0/1).
        # NOTE(review): if the bounds contain +/-inf, 0*inf yields nan;
        # confirm the bounds are finite.
        lbw +=[Zu_f*input_bnds_lo]
        ubw +=[Zu_f*input_bnds_up]
        w0 += [input_guess[k-1,:]]
        
        # Other fast inputs prediction
        Uk_f1 = fast_1_input_bnds[k-1,:]
        Uk_f3 = fast_3_input_bnds[k-1,:]
        
        # Create disturbance
        Dk = distb_bnds[k-1,:]
        
        # Simulate the model over one step, starting from the current state
        Ik = I_ode_f2(x0 = Xk, p = vertcat(Xsk, Xk_f1, Xk_f3,
                                           Usk, Uk_f1, Uk, Uk_f3,
                                           Uzk, Dk))
        X_int = Ik['xf']
        
        # Create new states variables
        xname = 'X' + str(k)
        Xk = MX.sym(xname, Nx_f2)
        w += [Xk]
        lbw += [state_bnds_lo]
        ubw += [state_bnds_up]
        w0 += [state_guess[k-1,:]]
        
        # Add dynamic constraints: integrator end state == next state variable
        g += [X_int - Xk]
        lbg += [[0]*Nx_f2]
        ubg += [[0]*Nx_f2]
        
        # Other fast states prediction.
        # These are updated after the integration: the step-k integration
        # used the previous values, while the output equation below pairs
        # the step-k values with X_k.
        Xk_f1 = fast_1_state_bnds[k-1,:]
        Xk_f3 = fast_3_state_bnds[k-1,:]
        
        # Create outputs variables
        yname = 'y' + str(k)
        Yk = MX.sym(yname, Ny_f2)
        w += [Yk]
        lbw += [output_bnds_lo]
        ubw += [output_bnds_up]
        w0 += [output_guess[k-1,:]]
        
        # Add output constraints (y_k equals the output model evaluated at X_k)
        g += [Yk - out_ies_f2(Xk, Xsk, Xk_f1, Xk_f3,
                              Usk, Uk_f1, Uk, Uk_f3,
                              Uzk, Dk)]
        lbg += [[0]*Ny_f2]
        ubg += [[0]*Ny_f2]
        
        # Inputs' constraints ***********************************
        # Add slack variables (one unbounded scalar per step,
        # penalised quadratically in the cost)
        ename = 'e' + str(k)
        ek = MX.sym(ename, 1)
        w += [ek]
        lbw += [[-np.inf]*1]
        ubw += [[np.inf]*1]
        w0 += [[0]*1]
        
        ue = ek
        
        # increment constraints on inputs
        # between two time instants
        # u(k) - u(k-1)
        # Skipped only for the first step of the first time instant.
        if i != 0 or k != 1:
            # If the binary flag switches between steps (|del_Zu| == 1 for a
            # 0/1 flag), the bound is widened by 1e5, which effectively
            # lifts the rate limit at the switch.
            del_Zu = fabs(Zu_f - past_uz)
            
            g += [Uk - past_u]
            lbg += [-1*(del_bns_u + del_Zu*1e5)]
            ubg += [del_bns_u + del_Zu*1e5]
            pass
        
        # ensure inputs within optimal values' neighbourhood
        # at a time instant
        # u(k) - u(k)*
        # The scalar slack is added to every component of Uk, so the band
        # can be violated at the cost of the alpha_ue penalty term.
        # NOTE(review): if Uc_nb is a plain Python list, -1*Uc_nb is [];
        # confirm it is a numpy array.
        g += [Uk - fast_input_setpnts[k-1,:] + ue]
        lbg += [-1*Uc_nb]
        ubg += [Uc_nb]
        
        # ensure fast inputs' increments between two iterations 
        # at the same time instant within acceptable range
        # u(k)^{c} - u(k)^{c-1}
        if tag != 1:
            g += [Uk - Uc_it_sq[k-1,:]]
            lbg += [-1*It_d_Uc]
            ubg += [It_d_Uc]
            pass
        
        # ensure fast inputs' increments between prediction at last time
        # instant and actual value at the current instant within acceptable range
        # u(k) - u(k/k-1)
        # Not applied on the last horizon step. The same 1e5 relaxation is
        # used when the binary flag differs from the previous prediction.
        if k != pred_horzn:
            pre_Zu_f = pre_sq_Uz[k-1,1]
            pre_del_Zu = fabs(Zu_f - pre_Zu_f)
            
            g += [Uk - pre_sq_Uc[k-1,:]]
            lbg += [-1*(pre_d_Uc + pre_del_Zu*1e5)]
            ubg += [pre_d_Uc + pre_del_Zu*1e5]
            pass
        # *******************************************************
        
        # Sell/Compensation/Penalty Price, Regulation factor
        pse_k = pse[k-1]
        pcm_k = pcm[k-1]
        ppn_k = ppn[k-1]
        rxi_k = rxi[k-1]
        rxi_as_k = rxi_as[k-1]
        yeb_k = yeb[k-1]
        
        # Create tracking deviation terms for fast states and inputs
        delta_Xf = Xk - fast_state_setpnts[k-1,:]
        delta_Uf = Uk - fast_input_setpnts[k-1,:]
        
        # The cost function
        # NOTE(review): nlpsol needs a scalar objective; the first two terms
        # are scalar only if Ny_f2 == 1 and Nuc_f2 == 1. Confirm in
        # Fast_2_model.
        # Output tracking towards (1 + rxi_k)*yeb_k
        J += alpha[0] * (Yk - (1 + rxi_k)*yeb_k)**2
        # Economic term: pf-weighted inputs plus a ppn-weighted squared
        # deviation, minus pmg*Dk[2], the sell revenue pse_k*Yk and the
        # pcm_k compensation
        J += alpha[1] * (pf*(Uk_f1[0] + Uk) + ppn_k*(Yk - (1 + rxi_k)*yeb_k)**2
                         - pmg*Dk[2] - pse_k*Yk - pcm_k*rxi_as_k*yeb_k)
        # Quadratic penalties: state/input deviation from setpoints and slack
        J += mtimes(mtimes(delta_Xf.T, alpha_xf), delta_Xf)
        J += mtimes(mtimes(delta_Uf.T, alpha_uf), delta_Uf)
        J += mtimes(mtimes(ue.T, alpha_ue), ue)
        
        # Remember this step's input and flag for the next rate constraint
        past_u = Uk
        past_uz = Zu_f
        
        pass
    
    # Concatenate decision variables and constraint terms
    w = vertcat(*w)
    lbw = vertcat(*lbw)
    ubw = vertcat(*ubw)
    w0 = vertcat(*w0)
    g = vertcat(*g)
    lbg = vertcat(*lbg)
    ubg = vertcat(*ubg)
    
    return w, lbw, ubw, w0, g, lbg, ubg, J

# In[2] Create Fast EMPC 2
def solve_opt_f2(w, lbw, ubw, w0, g, lbg, ubg, J):
    """Create an IPOPT solver for the F-EMPC 2 NLP and solve it once.

    Parameters
    ----------
    w, lbw, ubw, w0 : Decision vector, its lower/upper bounds and the
        initial guess.
    g, lbg, ubg : Constraint expressions and their lower/upper bounds.
    J : Objective expression.

    Returns
    -------
    tuple
        ``(optimal_x, objective)``: the optimal decision vector and the
        objective value, each as a flattened numpy array. Solver statistics
        are printed; solver success is not checked here.
    """
    print('Creat an F-EMPC 2 solver')
    problem = {'x': w, 'f': J, 'g': g}
    solver = nlpsol('nlp_solver', 'ipopt', problem)

    # Solve NLP
    print('Solve F-EMPC 2')
    solver_args = dict(x0=w0, lbx=lbw, ubx=ubw, lbg=lbg, ubg=ubg)
    result = solver(**solver_args)
    print(solver.stats())

    return result['x'].full().ravel(), result['f'].full().ravel()
